{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import Embedding, TimeDistributed, RepeatVector, LSTM, concatenate , Input, Reshape, Dense\n",
    "from keras.preprocessing.image import array_to_img, img_to_array, load_img\n",
    "import numpy as np\n",
    "from keras.applications.vgg16 import VGG16, preprocess_input\n",
    "from keras.models import Model\n",
    "from IPython.core.display import display, HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Lenght of longest sentence\n",
    "max_caption_len = 3\n",
    "#Size of vocabulary \n",
    "vocab_size = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one screenshot for each word, and turn them into digits \n",
    "images = []\n",
    "for i in range(2):\n",
    "    images.append(img_to_array(load_img('screenshot.jpg', target_size=(224, 224))))\n",
    "images = np.array(images, dtype=float)\n",
    "# Preprocess input for the VGG16 model\n",
    "images = preprocess_input(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Turn start tokens into one-hot encoding\n",
    "html_input = np.array(\n",
    "            [[[0., 0., 0.], #start\n",
    "             [0., 0., 0.],\n",
    "             [1., 0., 0.]],\n",
    "             [[0., 0., 0.], #start <HTML>Hello World!</HTML>\n",
    "             [1., 0., 0.],\n",
    "             [0., 1., 0.]]])\n",
    "\n",
    "#Turn next word into one-hot encoding\n",
    "next_words = np.array(\n",
    "            [[0., 1., 0.], # <HTML>Hello World!</HTML>\n",
    "             [0., 0., 1.]]) # end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the VGG16 model trained on imagenet and output the classification feature\n",
    "VGG = VGG16(weights=None, include_top=True)\n",
    "VGG.load_weights('/data/models/vgg16_weights_tf_dim_ordering_tf_kernels.h5')\n",
    "# Extract the features from the image\n",
    "features = VGG.predict(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load the feature to the network, apply a dense layer, and repeat the vector\n",
    "vgg_feature = Input(shape=(1000,))\n",
    "vgg_feature_dense = Dense(5)(vgg_feature)\n",
    "vgg_feature_repeat = RepeatVector(max_caption_len)(vgg_feature_dense)\n",
    "# Extract information from the input seqence \n",
    "language_input = Input(shape=(vocab_size, vocab_size))\n",
    "language_model = LSTM(5, return_sequences=True)(language_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the information from the image and the input\n",
    "decoder = concatenate([vgg_feature_repeat, language_model])\n",
    "# Extract information from the concatenated output\n",
    "decoder = LSTM(5, return_sequences=False)(decoder)\n",
    "# Predict which word comes next\n",
    "decoder_output = Dense(vocab_size, activation='softmax')(decoder)\n",
    "# Compile and run the neural network\n",
    "model = Model(inputs=[vgg_feature, language_input], outputs=decoder_output)\n",
    "model.compile(loss='categorical_crossentropy', optimizer='rmsprop')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the neural network\n",
    "model.fit([features, html_input], next_words, batch_size=2, shuffle=False, epochs=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_token = [1., 0., 0.] # start\n",
    "sentence = np.zeros((1, 3, 3)) # [[0,0,0], [0,0,0], [0,0,0]]\n",
    "sentence[0][2] = start_token # place start in empty sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Making the first prediction with the start token\n",
    "second_word = model.predict([np.array([features[1]]), sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the second word in the sentence and make the final prediction\n",
    "sentence[0][1] = start_token\n",
    "sentence[0][2] = np.round(second_word)\n",
    "third_word = model.predict([np.array([features[1]]), sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the start token and our two predictions in the sentence "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence[0][0] = start_token\n",
    "sentence[0][1] = np.round(second_word)\n",
    "sentence[0][2] = np.round(third_word)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform our one-hot predictions into the final tokens\n",
    "vocabulary = [\"start\", \"<HTML><center><H1>Hello World!</H1><center></HTML>\", \"end\"]\n",
    "html = \"\"\n",
    "for i in sentence[0]:\n",
    "    html += vocabulary[np.argmax(i)] + ' '"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(HTML(html[6:49]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
